{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Components\n",
    "\n",
    "A `WideDeep` (i.e. multimodal) model can take tabular data, text and images as inputs. These are fed into the model via the so-called `wide`, `deeptabular`, `deeptext` and `deepimage` model components."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. `wide`\n",
    "\n",
    "The `wide` component is a linear layer \"plugged\" into the output neuron(s). Here, non-linearities are captured via crossed columns. Quoting the paper directly, crossed columns are defined as follows: \"*For binary features, a cross-product transformation (e.g., “AND(gender=female, language=en)”) is 1 if and only if the constituent features (“gender=female” and “language=en”) are all 1, and 0 otherwise*\".\n",
    "\n",
    "The only particularity of our implementation is that the linear layer is implemented via an Embedding layer plus a bias. While the two implementations are equivalent, the latter is faster and far more memory efficient, since we do not need to one-hot encode the categorical features.\n",
    "\n",
    "Let's assume we have the following dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "from torch import nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>color</th>\n",
       "      <th>size</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>r</td>\n",
       "      <td>s</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>b</td>\n",
       "      <td>n</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>g</td>\n",
       "      <td>l</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  color size\n",
       "0     r    s\n",
       "1     b    n\n",
       "2     g    l"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.DataFrame({\"color\": [\"r\", \"b\", \"g\"], \"size\": [\"s\", \"n\", \"l\"]})\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One-hot encoded, the first observation would be:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs_0_oh = (np.array([1.0, 0.0, 0.0, 1.0, 0.0, 0.0])).astype(\"float32\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we instead simply encode the values numerically (label encoding, or `le`), the first observation would be:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs_0_le = (np.array([0, 3])).astype(\"int64\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that the actual implementation in the package starts from 1, reserving 0 for padding, i.e. unseen values.\n",
    "\n",
    "Now, let's check that the two implementations are equivalent:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we have 6 different values. Let's assume we are performing a regression, so pred_dim = 1\n",
    "lin = nn.Linear(6, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "emb = nn.Embedding(6, 1)\n",
    "emb.weight = nn.Parameter(lin.weight.reshape_as(emb.weight))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.5181], grad_fn=<ViewBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lin(torch.tensor(obs_0_oh))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.5181], grad_fn=<AddBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "emb(torch.tensor(obs_0_le)).sum() + lin.bias"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And this is precisely how the linear model `Wide` is implemented:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_widedeep.models import Wide"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ?Wide"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Wide(\n",
       "  (wide_linear): Embedding(11, 1, padding_idx=0)\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wide = Wide(input_dim=10, pred_dim=1)\n",
    "wide"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that even though the input dim is 10, the Embedding layer has 11 weights. Again, this is because we reserve `0` for padding, which is used for unseen values during the encoding process."
   ]
  },
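  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Coming back to the crossed columns quoted at the beginning of this section, the cell below is a minimal, pure-Python sketch of how a crossed column could be built for the small `color`/`size` dataset and then label encoded starting from 1, so that 0 stays reserved for unseen values. This is not the package's own preprocessing code: the helper names `cross` and `encode` and the `color_size` column name are assumptions made for illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical illustration in plain Python, not the package's preprocessing code\n",
    "rows = [\n",
    "    {\"color\": \"r\", \"size\": \"s\"},\n",
    "    {\"color\": \"b\", \"size\": \"n\"},\n",
    "    {\"color\": \"g\", \"size\": \"l\"},\n",
    "]\n",
    "\n",
    "\n",
    "def cross(row, cols):\n",
    "    # a crossed column combines the values of its constituent columns\n",
    "    return \"-\".join(row[c] for c in cols)\n",
    "\n",
    "\n",
    "for row in rows:\n",
    "    row[\"color_size\"] = cross(row, [\"color\", \"size\"])\n",
    "\n",
    "# one integer per (column, value) pair, starting at 1: 0 is kept for unseen values\n",
    "vocab = {}\n",
    "for row in rows:\n",
    "    for col, val in row.items():\n",
    "        vocab.setdefault((col, val), len(vocab) + 1)\n",
    "\n",
    "\n",
    "def encode(row):\n",
    "    # unknown (column, value) pairs fall back to the padding index 0\n",
    "    return [vocab.get((col, val), 0) for col, val in row.items()]\n",
    "\n",
    "\n",
    "X_wide = [encode(row) for row in rows]\n",
    "unseen = encode({\"color\": \"r\", \"size\": \"xl\", \"color_size\": \"r-xl\"})\n",
    "\n",
    "# len(vocab) unique values would need an Embedding with len(vocab) + 1 rows (row 0 for padding)\n",
    "print(X_wide, unseen, len(vocab))"
   ]
  },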
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. `deeptabular`\n",
    "\n",
    "The `deeptabular` model alone is what would normally be referred to as Deep Learning for tabular data. As mentioned a number of times throughout the library, each component can be used independently. Therefore, it is perfectly possible to use any of the models below on its own. There are just a couple of simple requirements, which will be covered in a later notebook.\n",
    "\n",
    "At the time of writing, the following models are available in `pytorch-widedeep` to do DL for tabular data:\n",
    "\n",
    "1. `TabMlp`\n",
    "2. `ContextAttentionMLP`\n",
    "3. `SelfAttentionMLP`\n",
    "4. `TabResnet`\n",
    "5. `TabNet`\n",
    "6. `TabTransformer`\n",
    "7. `FTTransformer`\n",
    "8. `SAINT`\n",
    "9. `TabFastFormer`\n",
    "10. `TabPerceiver`\n",
    "\n",
    "Let's have a look at one of them. For more information on each of these models, please have a look at the documentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_widedeep.preprocessing import TabPreprocessor\n",
    "from pytorch_widedeep.models import TabMlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = {\n",
    "    \"cat1\": np.random.choice([\"A\", \"B\", \"C\"], size=20),\n",
    "    \"cat2\": np.random.choice([\"X\", \"Y\"], size=20),\n",
    "    \"cont1\": np.random.rand(20),\n",
    "    \"cont2\": np.random.rand(20),\n",
    "}\n",
    "\n",
    "df = pd.DataFrame(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>cat1</th>\n",
       "      <th>cat2</th>\n",
       "      <th>cont1</th>\n",
       "      <th>cont2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>A</td>\n",
       "      <td>Y</td>\n",
       "      <td>0.789347</td>\n",
       "      <td>0.561789</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>C</td>\n",
       "      <td>X</td>\n",
       "      <td>0.050822</td>\n",
       "      <td>0.061538</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>A</td>\n",
       "      <td>Y</td>\n",
       "      <td>0.863784</td>\n",
       "      <td>0.241967</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>C</td>\n",
       "      <td>X</td>\n",
       "      <td>0.917848</td>\n",
       "      <td>0.644658</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>C</td>\n",
       "      <td>Y</td>\n",
       "      <td>0.042328</td>\n",
       "      <td>0.417303</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  cat1 cat2     cont1     cont2\n",
       "0    A    Y  0.789347  0.561789\n",
       "1    C    X  0.050822  0.061538\n",
       "2    A    Y  0.863784  0.241967\n",
       "3    C    X  0.917848  0.644658\n",
       "4    C    Y  0.042328  0.417303"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# see the docs for details on all params/options\n",
    "tab_preprocessor = TabPreprocessor(\n",
    "    cat_embed_cols=[\"cat1\", \"cat2\"],\n",
    "    continuous_cols=[\"cont1\", \"cont2\"],\n",
    "    embedding_rule=\"fastai\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/javierrodriguezzaurin/Projects/pytorch-widedeep/pytorch_widedeep/preprocessing/tab_preprocessor.py:358: UserWarning: Continuous columns will not be normalised\n",
      "  warnings.warn(\"Continuous columns will not be normalised\")\n"
     ]
    }
   ],
   "source": [
    "X_tab = tab_preprocessor.fit_transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TabMlp(\n",
       "  (cat_embed): DiffSizeCatEmbeddings(\n",
       "    (embed_layers): ModuleDict(\n",
       "      (emb_layer_cat1): Embedding(4, 3, padding_idx=0)\n",
       "      (emb_layer_cat2): Embedding(3, 2, padding_idx=0)\n",
       "    )\n",
       "    (embedding_dropout): Dropout(p=0.0, inplace=False)\n",
       "  )\n",
       "  (cont_norm): Identity()\n",
       "  (cont_embed): ContEmbeddings(\n",
       "    INFO: [ContLinear = weight(n_cont_cols, embed_dim) + bias(n_cont_cols, embed_dim)]\n",
       "    (linear): ContLinear(n_cont_cols=2, embed_dim=4, embed_dropout=0.0)\n",
       "    (dropout): Dropout(p=0.0, inplace=False)\n",
       "  )\n",
       "  (encoder): MLP(\n",
       "    (mlp): Sequential(\n",
       "      (dense_layer_0): Sequential(\n",
       "        (0): Linear(in_features=13, out_features=8, bias=True)\n",
       "        (1): ReLU(inplace=True)\n",
       "        (2): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "      (dense_layer_1): Sequential(\n",
       "        (0): Linear(in_features=8, out_features=4, bias=True)\n",
       "        (1): ReLU(inplace=True)\n",
       "        (2): Dropout(p=0.1, inplace=False)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# toy example just to build a model.\n",
    "tabmlp = TabMlp(\n",
    "    column_idx=tab_preprocessor.column_idx,\n",
    "    cat_embed_input=tab_preprocessor.cat_embed_input,\n",
    "    continuous_cols=tab_preprocessor.continuous_cols,\n",
    "    embed_continuous_method=\"standard\",\n",
    "    cont_embed_dim=4,\n",
    "    mlp_hidden_dims=[8, 4],\n",
    "    mlp_linear_first=True,\n",
    ")\n",
    "tabmlp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's describe the model a bit. First we have what we call a `DiffSizeCatEmbeddings` layer, where categorical columns with a different number of unique categories are encoded with embeddings of different dimensions.\n",
    "\n",
    "The continuous columns are not normalised (the normalisation layer is just the identity) and are embedded via the \"standard\" method, using a so-called `ContLinear` layer. This layer displays some `INFO` that tells us what it is (`[ContLinear = weight(n_cont_cols, embed_dim) + bias(n_cont_cols, embed_dim)]`). There are two other options available to embed the continuous columns, based on the paper [On Embeddings for Numerical Features in Tabular Deep Learning](https://arxiv.org/abs/2203.05556): `PieceWise` and `Periodic`. All of them are available via the `embed_continuous_method` param, which can take the values `\"standard\"`, `\"piecewise\"` and `\"periodic\"`.\n",
    "\n",
    "The embedded categorical and continuous columns are then concatenated and passed to an MLP. The categorical embeddings contribute 3 and 2 dims, and each of the 2 continuous columns contributes 4 dims, hence $3 + 2 + (4 * 2) = 13$ input dims."
   ]
  },
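  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The 13 input dimensions can be reproduced with a few lines of plain Python. This is a sketch under the assumption that the `\"fastai\"` embedding rule corresponds to the rule of thumb `min(600, round(1.6 * n_cat**0.56))`; the helper name `fastai_embed_dim` is made up for this example and is not part of the package. With 3 and 2 unique categories this gives embedding sizes of 3 and 2, in agreement with the model printed above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fastai_embed_dim(n_cat: int) -> int:\n",
    "    # assumed rule of thumb, see the note above\n",
    "    return int(min(600, round(1.6 * n_cat**0.56)))\n",
    "\n",
    "\n",
    "# number of unique categories per column: \"A\", \"B\", \"C\" and \"X\", \"Y\"\n",
    "n_unique = {\"cat1\": 3, \"cat2\": 2}\n",
    "cat_dims = {col: fastai_embed_dim(n) for col, n in n_unique.items()}\n",
    "\n",
    "# each continuous column is embedded into cont_embed_dim dimensions\n",
    "n_cont_cols, cont_embed_dim = 2, 4\n",
    "mlp_input_dim = sum(cat_dims.values()) + n_cont_cols * cont_embed_dim\n",
    "\n",
    "print(cat_dims, mlp_input_dim)"
   ]
  },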
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. `deeptext`\n",
    "\n",
    "At the time of writing, `pytorch-widedeep` offers three models that can be passed to `WideDeep` as the `deeptext` component. These are:\n",
    "\n",
    "1. BasicRNN\n",
    "2. AttentiveRNN\n",
    "3. StackedAttentiveRNN\n",
    "\n",
    "For details on each of these models, please have a look at the package documentation.\n",
    "\n",
    "We will soon integrate with Hugging Face. In the meantime, let me insist: it is perfectly possible to use custom models for each component (please have a look at the corresponding notebook). In general, simply build them and pass them as the corresponding parameters. Note that custom models MUST return the last layer of activations (i.e. not the final prediction), so that these activations can be collected by `WideDeep` and combined accordingly. In addition, the models MUST also contain an `output_dim` attribute with the size of this last layer of activations.\n",
    "\n",
    "Let's have a look at the `BasicRNN` model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_widedeep.models import BasicRNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/javierrodriguezzaurin/.pyenv/versions/3.10.13/envs/widedeep310/lib/python3.10/site-packages/torch/nn/modules/rnn.py:82: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.1 and num_layers=1\n",
      "  warnings.warn(\"dropout option adds dropout after all but last \"\n"
     ]
    }
   ],
   "source": [
    "basic_rnn = BasicRNN(vocab_size=4, hidden_dim=4, n_layers=1, padding_idx=0, embed_dim=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BasicRNN(\n",
       "  (word_embed): Embedding(4, 4, padding_idx=0)\n",
       "  (rnn): LSTM(4, 4, batch_first=True, dropout=0.1)\n",
       "  (rnn_mlp): Identity()\n",
       ")"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "basic_rnn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You could, if you wanted, add a Fully Connected Head (FC-Head) on top of it."
   ]
  },
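  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Regarding custom models, the cell below sketches what a minimal custom `deeptext` component could look like in plain PyTorch. The class name `MyTextModel` and its parameters are hypothetical. The only requirements taken from the text in this section are that `forward` returns the last layer of activations (not a prediction) and that the model exposes an `output_dim` attribute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class MyTextModel(nn.Module):\n",
    "    # hypothetical custom text component, for illustration only\n",
    "    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int):\n",
    "        super().__init__()\n",
    "        self.word_embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)\n",
    "        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)\n",
    "        # size of the activations returned by forward\n",
    "        self.output_dim = hidden_dim\n",
    "\n",
    "    def forward(self, X: torch.Tensor) -> torch.Tensor:\n",
    "        embed = self.word_embed(X.long())\n",
    "        _, h = self.rnn(embed)  # h: (num_layers, batch, hidden_dim)\n",
    "        return h[-1]  # last hidden state: (batch, hidden_dim)\n",
    "\n",
    "\n",
    "my_text_model = MyTextModel(vocab_size=4, embed_dim=4, hidden_dim=8)\n",
    "X_text = torch.randint(0, 4, (5, 6))  # 5 padded sequences of 6 token ids\n",
    "out = my_text_model(X_text)\n",
    "print(out.shape, my_text_model.output_dim)"
   ]
  },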
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. `deepimage`\n",
    "\n",
    "At the time of writing, `pytorch-widedeep` is integrated with torchvision via the `Vision` class. This means that it is possible to use a variant of the following architectures:\n",
    "\n",
    "1. resnet\n",
    "2. shufflenet\n",
    "3. resnext\n",
    "4. wide_resnet\n",
    "5. regnet\n",
    "6. densenet\n",
    "7. mobilenet\n",
    "8. mnasnet\n",
    "9. efficientnet\n",
    "10. squeezenet\n",
    "\n",
    "The user can choose which layers will be trainable. Alternatively, if none of these architectures is useful, one could use a simple, fully trained CNN (please see the package documentation) or pass a custom model.\n",
    "\n",
    "Let's have a look:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_widedeep.models import Vision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "resnet = Vision(pretrained_model_setup=\"resnet18\", n_trainable=0)"
   ]
  },
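  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an aside on trainable layers: in plain PyTorch, a layer is frozen by setting `requires_grad = False` on its parameters. The cell below is a generic illustration on a small, hypothetical backbone. It is not the `Vision` implementation, and `freeze_all_but_last` is a name made up for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# small hypothetical backbone, for illustration only\n",
    "backbone = nn.Sequential(\n",
    "    nn.Conv2d(3, 8, kernel_size=3, padding=1),\n",
    "    nn.ReLU(),\n",
    "    nn.Conv2d(8, 16, kernel_size=3, padding=1),\n",
    "    nn.ReLU(),\n",
    "    nn.AdaptiveAvgPool2d(output_size=(1, 1)),\n",
    ")\n",
    "\n",
    "\n",
    "def freeze_all_but_last(model: nn.Sequential, n_trainable: int) -> None:\n",
    "    # only layers that own parameters are candidates for training\n",
    "    param_layers = [m for m in model if len(list(m.parameters())) > 0]\n",
    "    n_frozen = len(param_layers) - n_trainable\n",
    "    for i, layer in enumerate(param_layers):\n",
    "        for p in layer.parameters():\n",
    "            # the first n_frozen parametrised layers are frozen, the rest stay trainable\n",
    "            p.requires_grad = i >= n_frozen\n",
    "\n",
    "\n",
    "freeze_all_but_last(backbone, n_trainable=1)\n",
    "trainable = [name for name, p in backbone.named_parameters() if p.requires_grad]\n",
    "print(trainable)"
   ]
  },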
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Vision(\n",
       "  (features): Sequential(\n",
       "    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (2): ReLU(inplace=True)\n",
       "    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "    (4): Sequential(\n",
       "      (0): BasicBlock(\n",
       "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (1): BasicBlock(\n",
       "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (5): Sequential(\n",
       "      (0): BasicBlock(\n",
       "        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): BasicBlock(\n",
       "        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (6): Sequential(\n",
       "      (0): BasicBlock(\n",
       "        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): BasicBlock(\n",
       "        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (7): Sequential(\n",
       "      (0): BasicBlock(\n",
       "        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): BasicBlock(\n",
       "        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (8): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "resnet"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "3b99005fd577fa40f3cce433b2b92303885900e634b2b5344c07c59d06c8792d"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
